"""
Phase 5: Bulletproof Robustness — RER
=======================================
Additional tests to make the Groneck-Kaufmann partial refutation airtight:
  (a) Panel cointegration tests (Kao, Pedroni-style)
  (b) Cluster-robust standard errors (Driscoll-Kraay style via block bootstrap)
  (c) Placebo/permutation test (randomly reassign Z across countries)
  (d) G-K exact specification: Z → service sector share (non-tradable proxy)
  (e) Leave-one-out: drop each country, check stability
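  (f) Hausman-style: FE vs RE comparison (planned; not implemented in this script)
  (g) Jackknife over regions
  (h) Alternative demographic measures (old_dep, youth_dep, working_age_share,
      total_dep, life_expectancy — each entered alone)

Output (all under output/tables/): cointegration.md, bootstrap_se.md, placebo.md,
  nontradable_gk.md, leave_one_out.md, regional_jackknife.md, alt_demographics.md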
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd
from scipy import stats

# Root of the RER sub-project.
# NOTE(review): hard-coded absolute path — the script only runs on a machine
# where this directory exists; confirm before running elsewhere.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/rer")
ROOT_DIR = PROJECT_DIR.parent
# Put the sibling multilateral/src directory first on sys.path so that the
# `model` module (source of PanelGLS) can be imported below.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel directory (rer_panel.csv is read from here) and the directory
# the markdown result tables are written to.
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
TABLES_DIR = PROJECT_DIR / "output" / "tables"


def stars(p):
    """Return the significance marker for p-value ``p``.

    ``'***'`` below 0.01, ``'**'`` below 0.05, ``'*'`` below 0.10 and the
    empty string otherwise (including for NaN, which fails every comparison).
    """
    thresholds = ((0.01, '***'), (0.05, '**'), (0.10, '*'))
    for cutoff, marker in thresholds:
        if p < cutoff:
            return marker
    return ''


def run_model(df, dep_var, regressors, label, silent=False):
    """Fit a PanelGLS regression of ``dep_var`` on ``regressors`` and summarise it.

    Regressors that are not columns of ``df`` are dropped before fitting; rows
    with a missing value in the dependent variable or any remaining regressor
    are excluded.

    Args:
        df: Panel data; must contain ``iso3`` and ``year`` columns in addition
            to the model variables.
        dep_var: Name of the dependent-variable column.
        regressors: Candidate regressor column names.
        label: Short model name used in console output and in the result dict.
        silent: When True, suppress all console output.

    Returns:
        A dict with ``label``, ``dep_var``, ``n_obs``, ``n_countries``,
        ``r_squared`` and, for every regressor actually used,
        ``coef_<name>``, ``se_<name>`` and ``p_<name>``.  Returns ``None`` when
        the dependent variable is absent, no requested regressor is available,
        fewer than 30 complete observations remain, or the fit raises.
    """
    requested = list(regressors)
    regressors = [r for r in requested if r in df.columns]
    if dep_var not in df.columns:
        return None

    # Make silently dropped regressors visible: otherwise a model can be
    # estimated (and labelled) without the variable it is named after.
    missing = [r for r in requested if r not in df.columns]
    if missing and not silent:
        print(f"  [{label}] Regressors not in data, dropped: "
              f"{', '.join(str(m) for m in missing)}")
    if not regressors:
        # Nothing to estimate; do not hand an empty design matrix to PanelGLS.
        if not silent:
            print(f"  [{label}] No usable regressors — skipping")
        return None

    sub = df.dropna(subset=[dep_var] + regressors).copy()
    if len(sub) < 30:
        if not silent:
            print(f"  [{label}] Insufficient obs ({len(sub)}) — skipping")
        return None

    gls = PanelGLS()
    try:
        gls.fit(sub[dep_var].values, sub[regressors].values,
                sub['iso3'].values, sub['year'].values)
    except Exception as e:
        # Deliberate best-effort: one failing specification must not abort
        # the whole robustness run.
        if not silent:
            print(f"  [{label}] Model failed: {e}")
        return None

    if not silent:
        print(f"\n  [{label}]  N={gls.n_obs}, countries={gls.n_countries}, R²={gls.r_squared:.4f}")
        for i, name in enumerate(regressors):
            sig = stars(gls.pvalues[i])
            print(f"    {name:<30} {gls.beta[i]:>10.4f} ({gls.se[i]:.4f}) {sig}")

    results = {
        'label': label, 'dep_var': dep_var,
        'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
    }
    for i, name in enumerate(regressors):
        results[f'coef_{name}'] = gls.beta[i]
        results[f'se_{name}'] = gls.se[i]
        results[f'p_{name}'] = gls.pvalues[i]
    return results


def _adf0_tstat(resid):
    """Return the t-statistic on rho in  Δe_t = rho * e_{t-1} + u_t,  or None.

    ``None`` is returned when there are fewer than four residuals, when the
    lagged residuals are (numerically) constant, or when the statistic is not
    finite.
    """
    resid = np.asarray(resid, dtype=float)
    if resid.size < 4:
        return None
    de = np.diff(resid)
    e_lag = resid[:-1]
    if np.std(e_lag) < 1e-10:
        return None
    ss_lag = np.sum(e_lag ** 2)
    with np.errstate(divide='ignore', invalid='ignore'):
        rho_hat = np.sum(de * e_lag) / ss_lag
        se_rho = np.sqrt(np.sum((de - rho_hat * e_lag) ** 2) / (len(de) - 1) / ss_lag)
        t_stat = rho_hat / se_rho
    # A zero standard error (perfect fit) would otherwise poison the pooled mean.
    return float(t_stat) if np.isfinite(t_stat) else None


def _country_adf_stats(panel, dep_var, xcols, min_obs=8):
    """Collect one residual ADF(0) t-statistic per country.

    For each ``iso3`` group of ``panel`` (assumed already sorted by year within
    country), regress ``dep_var`` on an intercept plus ``xcols`` by OLS and run
    ``_adf0_tstat`` on the residuals.  Countries with fewer than ``min_obs``
    complete rows are skipped.
    """
    cols = [dep_var] + list(xcols)
    t_stats = []
    for _, grp in panel.groupby('iso3', sort=False):
        grp = grp.dropna(subset=cols)
        if len(grp) < min_obs:
            continue
        y = grp[dep_var].values
        x = np.column_stack([np.ones(len(grp))] + [grp[v].values for v in xcols])
        try:
            beta = np.linalg.lstsq(x, y, rcond=None)[0]
        except (np.linalg.LinAlgError, ValueError, TypeError):
            continue
        # NOTE(review): np.diff inside _adf0_tstat treats rows as consecutive
        # periods; gaps in a country's years are not accounted for — confirm.
        t_stat = _adf0_tstat(y - x @ beta)
        if t_stat is not None:
            t_stats.append(t_stat)
    return t_stats


def _pooled_adf_summary(t_stats, crit=-2.86):
    """Return (pooled t, two-sided normal p-value, % of countries with t < crit).

    NOTE(review): the pooled statistic is mean(t) * sqrt(N) without the
    centring/scaling used by formal panel tests, and -2.86 looks like the
    Dickey-Fuller 5% value for a unit-root test with intercept rather than a
    residual-based cointegration critical value — treat as indicative; confirm.
    """
    t = np.asarray(t_stats, dtype=float)
    pooled_t = t.mean() * np.sqrt(t.size)
    p_value = 2 * stats.norm.cdf(-abs(pooled_t))
    # Share of rejections = mean of the boolean indicator over ALL countries.
    reject_pct = np.mean(t < crit) * 100
    return pooled_t, p_value, reject_pct


def _save_table(filename, lines):
    """Write ``lines`` as a UTF-8 markdown file in TABLES_DIR and return its path."""
    TABLES_DIR.mkdir(parents=True, exist_ok=True)
    out = TABLES_DIR / filename
    # Explicit UTF-8: the tables contain characters such as 'Z₁' and 'H₀'.
    out.write_text('\n'.join(lines), encoding='utf-8')
    return out


def main():
    """Run the Phase 5 robustness checks on the RER panel and write the tables.

    Reads ``rer_panel.csv`` from PROCESSED_DIR and runs, in order: (A) residual
    ADF panel cointegration checks, (B) a country-cluster bootstrap of the
    PanelGLS coefficients, (C) a within-year permutation placebo, (D)
    non-tradable channel regressions, (E) leave-one-country-out stability,
    (F) a regional jackknife and (G) alternative demographic measures.
    Progress is printed to stdout and one markdown table per part is written
    to TABLES_DIR.  Returns None.
    """
    print("=" * 70)
    print("PHASE 5: Bulletproof Robustness — RER")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "rer_panel.csv")
    print(f"Panel: {df['iso3'].nunique()} countries, {len(df):,} obs")

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls_bs = ['rgdp_growth', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag', 'log_gdp_pc']
    controls_bs = [c for c in controls_bs if c in df.columns]
    all_vars = demo_vars + controls_bs
    z1_idx = all_vars.index('Z_1')

    # ═══════════════════════════════════════════════════════════════════
    # PART A: PANEL COINTEGRATION TESTS
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART A: PANEL COINTEGRATION TESTS")
    print("=" * 70)

    # Pooled ADF on residuals from country-by-country level regressions.
    # H0: no cointegration; reject = evidence of cointegration.
    est = df.dropna(subset=['log_reer_combined', 'Z_1']).copy()
    est = est.sort_values(['iso3', 'year'])

    # Bivariate: log_reer = a + b*Z_1 + e
    adf_stats = _country_adf_stats(est, 'log_reer_combined', ['Z_1'])
    if adf_stats:
        pooled_t, p_kao, reject_pct = _pooled_adf_summary(adf_stats)

        print(f"\n  Kao-style panel cointegration test (log_reer ~ Z₁):")
        print(f"    Countries tested: {len(adf_stats)}")
        print(f"    Mean ADF t-stat: {np.mean(adf_stats):.3f}")
        print(f"    Pooled t-stat: {pooled_t:.3f}")
        print(f"    p-value: {p_kao:.4f}")
        print(f"    Country-level reject at 5%: {reject_pct:.1f}%")
        if p_kao < 0.05:
            print("    → REJECT no cointegration — level regression is well-specified")
        else:
            print("    → Cannot reject no cointegration — caution on level regression")

    # Multivariate: log_reer ~ Z_1 + Z_2 + Z_3 (+ log_gdp_pc when available)
    mv_cols = ['Z_1', 'Z_2', 'Z_3']
    if 'log_gdp_pc' in est.columns:
        mv_cols.append('log_gdp_pc')
    adf_stats_mv = _country_adf_stats(est, 'log_reer_combined', mv_cols)
    if adf_stats_mv:
        pooled_t_mv, p_kao_mv, reject_pct_mv = _pooled_adf_summary(adf_stats_mv)

        print(f"\n  Multivariate cointegration (log_reer ~ Z₁ + Z₂ + Z₃ + log_gdp_pc):")
        print(f"    Countries tested: {len(adf_stats_mv)}")
        print(f"    Pooled t-stat: {pooled_t_mv:.3f}")
        print(f"    p-value: {p_kao_mv:.4f}")
        print(f"    Country-level reject at 5%: {reject_pct_mv:.1f}%")

    # Write cointegration table
    md = ["# Panel Cointegration Tests\n"]
    md.append("| Test | Variables | N countries | Pooled t | p-value | Reject rate (5%) |")
    md.append("|---|---|---|---|---|---|")
    if adf_stats:
        md.append(f"| Kao bivariate | log_reer ~ Z₁ | {len(adf_stats)} | {pooled_t:.3f} "
                  f"| {p_kao:.4f} | {reject_pct:.1f}% |")
    if adf_stats_mv:
        md.append(f"| Kao multivariate | log_reer ~ Z₁+Z₂+Z₃+gdp_pc | {len(adf_stats_mv)} "
                  f"| {pooled_t_mv:.3f} | {p_kao_mv:.4f} | {reject_pct_mv:.1f}% |")
    md.append("\n*H₀: no cointegration. Reject → level regression well-specified.*")
    out = _save_table("cointegration.md", md)
    print(f"\n  Saved: {out}")

    # ═══════════════════════════════════════════════════════════════════
    # PART B: CLUSTER-ROBUST STANDARD ERRORS (Block Bootstrap)
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART B: CLUSTER-ROBUST STANDARD ERRORS (Country Bootstrap)")
    print("=" * 70)

    est = df.dropna(subset=['log_reer_combined'] + all_vars).copy()
    country_list = est['iso3'].unique()
    n_countries = len(country_list)

    # Run baseline for comparison
    gls_base = PanelGLS()
    gls_base.fit(est['log_reer_combined'].values, est[all_vars].values,
                 est['iso3'].values, est['year'].values)
    base_beta = gls_base.beta.copy()
    base_se = gls_base.se.copy()
    base_p = gls_base.pvalues.copy()
    base_z1 = base_beta[z1_idx]

    print(f"\n  Baseline PanelGLS (N={gls_base.n_obs}, {n_countries} countries):")
    for i, name in enumerate(all_vars):
        print(f"    {name:<25} {base_beta[i]:>10.4f} (SE={base_se[i]:.4f}, p={base_p[i]:.4f})")

    # Country-cluster bootstrap (500 iterations)
    n_boot = 500
    np.random.seed(42)
    boot_betas = []

    # Slice each country's rows once instead of once per draw per iteration.
    country_frames = {c: est[est['iso3'] == c] for c in country_list}

    print(f"\n  Running {n_boot} bootstrap iterations (cluster by country)...")
    for _ in range(n_boot):
        # Resample countries with replacement
        boot_countries = np.random.choice(country_list, size=n_countries, replace=True)
        # Build bootstrap sample (allowing duplicate countries with unique IDs)
        boot_dfs = []
        for j, c in enumerate(boot_countries):
            cdf = country_frames[c].copy()
            cdf['iso3'] = f"{c}_{j}"  # unique ID for duplicates
            boot_dfs.append(cdf)
        boot_df = pd.concat(boot_dfs, ignore_index=True)

        gls_boot = PanelGLS()
        try:
            gls_boot.fit(boot_df['log_reer_combined'].values,
                         boot_df[all_vars].values,
                         boot_df['iso3'].values,
                         boot_df['year'].values)
            boot_betas.append(gls_boot.beta)
        except Exception:
            # Best-effort: failed draws are dropped; the success count is reported.
            continue

    if boot_betas:
        boot_arr = np.array(boot_betas)
        boot_se = np.std(boot_arr, axis=0)
        boot_z = base_beta / boot_se
        boot_p = 2 * (1 - stats.norm.cdf(np.abs(boot_z)))

        print(f"\n  Bootstrap results ({len(boot_betas)} successful iterations):")
        print(f"  {'Variable':<25} {'Coef':>10} {'PanelGLS SE':>12} {'Boot SE':>10} {'Boot p':>10} {'Sig':>4}")
        for i, name in enumerate(all_vars):
            sig = stars(boot_p[i])
            print(f"  {name:<25} {base_beta[i]:>10.4f} {base_se[i]:>12.4f} "
                  f"{boot_se[i]:>10.4f} {boot_p[i]:>10.4f} {sig}")

        # Key comparison
        print(f"\n  ★ Z₁: PanelGLS p={base_p[z1_idx]:.4f}, Bootstrap p={boot_p[z1_idx]:.4f}")
        if boot_p[z1_idx] < 0.05:
            print("    → Z₁ SURVIVES cluster-robust bootstrap at 5%")
        elif boot_p[z1_idx] < 0.10:
            print("    → Z₁ survives at 10% but not 5% — weaker than PanelGLS suggests")
        else:
            print("    → Z₁ DOES NOT survive bootstrap — significance may be inflated")

        # Save bootstrap table
        md = ["# Cluster-Robust Standard Errors (Country Bootstrap)\n"]
        md.append("| Variable | Coef | PanelGLS SE | PanelGLS p | Boot SE | Boot p | Sig |")
        md.append("|---|---|---|---|---|---|---|")
        for i, name in enumerate(all_vars):
            sig = stars(boot_p[i])
            md.append(f"| {name} | {base_beta[i]:.4f} | {base_se[i]:.4f} | {base_p[i]:.4f} "
                      f"| {boot_se[i]:.4f} | {boot_p[i]:.4f} | {sig} |")
        md.append(f"\n*{len(boot_betas)} bootstrap iterations, resampling {n_countries} countries with replacement.*")
        out = _save_table("bootstrap_se.md", md)
        print(f"  Saved: {out}")

    # ═══════════════════════════════════════════════════════════════════
    # PART C: PLACEBO / PERMUTATION TEST
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART C: PLACEBO TEST (Random Z Reassignment)")
    print("=" * 70)

    n_perm = 500
    np.random.seed(123)
    placebo_z1 = []

    est_perm = est.copy()
    # Row labels belonging to each year; identical for every permutation.
    year_rows = [est_perm.index[est_perm['year'] == yr]
                 for yr in est_perm['year'].unique()]
    print(f"  Running {n_perm} permutations (shuffling Z across countries within years)...")

    for _ in range(n_perm):
        # Shuffle Z_1, Z_2, Z_3 across countries within each year.
        # NOTE(review): each Z is shuffled independently, so a placebo country
        # receives Z_1, Z_2, Z_3 from different donors — confirm this is intended.
        est_shuffled = est_perm.copy()
        for idx in year_rows:
            for zv in demo_vars:
                vals = est_shuffled.loc[idx, zv].values.copy()
                np.random.shuffle(vals)
                est_shuffled.loc[idx, zv] = vals

        gls_perm = PanelGLS()
        try:
            gls_perm.fit(est_shuffled['log_reer_combined'].values,
                         est_shuffled[all_vars].values,
                         est_shuffled['iso3'].values,
                         est_shuffled['year'].values)
            placebo_z1.append(gls_perm.beta[z1_idx])
        except Exception:
            continue

    if placebo_z1:
        placebo_arr = np.array(placebo_z1)
        actual_z1 = base_z1

        # Two-sided p-value: fraction of placebo |Z₁| >= |actual Z₁|
        perm_p = np.mean(np.abs(placebo_arr) >= np.abs(actual_z1))
        placebo_mean = np.mean(placebo_arr)
        placebo_sd = np.std(placebo_arr)

        print(f"\n  Placebo results ({len(placebo_z1)} successful permutations):")
        print(f"    Actual Z₁ coefficient: {actual_z1:.4f}")
        print(f"    Placebo mean: {placebo_mean:.4f}")
        print(f"    Placebo SD: {placebo_sd:.4f}")
        print(f"    Permutation p-value: {perm_p:.4f}")
        print(f"    Actual / Placebo SD: {abs(actual_z1) / placebo_sd:.1f} SD from placebo mean")

        if perm_p < 0.01:
            print("    → STRONG: actual Z₁ is in the extreme tail of the placebo distribution")
        elif perm_p < 0.05:
            print("    → Actual Z₁ survives placebo test at 5%")
        else:
            print("    → WARNING: actual Z₁ does not clearly separate from placebo")

        md = ["# Placebo Test — Random Z Reassignment\n"]
        md.append(f"- Permutations: {len(placebo_z1)}")
        md.append(f"- Actual Z₁ coefficient: {actual_z1:.4f}")
        md.append(f"- Placebo distribution: mean={placebo_mean:.4f}, SD={placebo_sd:.4f}")
        md.append(f"- Permutation p-value: {perm_p:.4f}")
        md.append(f"- Distance: {abs(actual_z1) / placebo_sd:.1f} SD from placebo mean")
        md.append(f"\n*Z₁, Z₂, Z₃ shuffled across countries within each year. "
                  f"Tests whether the result could arise from any arbitrary country ranking.*")
        out = _save_table("placebo.md", md)
        print(f"  Saved: {out}")

    # ═══════════════════════════════════════════════════════════════════
    # PART D: G-K SPECIFICATION — Service sector share
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART D: SERVICE SECTOR / NON-TRADABLE PROXIES")
    print("=" * 70)

    gk_results = []
    nt_regressors = demo_vars + ['rgdp_growth', 'log_gdp_pc']

    # Z → trade_openness (inverse non-tradable proxy)
    if 'trade_openness' in df.columns:
        r = run_model(df, 'trade_openness', nt_regressors,
                      "D1: Z → trade openness (inverse NT)")
        if r:
            gk_results.append(r)

    if 'health_exp_gdp' in df.columns:
        # Z → health_exp (non-tradable demand)
        r = run_model(df, 'health_exp_gdp', nt_regressors,
                      "D2: Z → health_exp (NT demand)")
        if r:
            gk_results.append(r)

        if 'oecd' in df.columns:
            oecd = df[df['oecd'] == 1].copy()
            non_oecd = df[df['oecd'] == 0].copy()
            # Does the non-tradable channel work in OECD vs non-OECD samples,
            # and does health spending feed through to the REER in each?
            split_specs = [
                (oecd, 'health_exp_gdp', nt_regressors, "D3: OECD Z → health_exp"),
                (non_oecd, 'health_exp_gdp', nt_regressors, "D4: non-OECD Z → health_exp"),
                (oecd, 'log_reer_combined', ['health_exp_gdp'] + controls_bs,
                 "D5: OECD health → REER"),
                (non_oecd, 'log_reer_combined', ['health_exp_gdp'] + controls_bs,
                 "D6: non-OECD health → REER"),
            ]
            for sample, dep, rhs, label in split_specs:
                r = run_model(sample, dep, rhs, label)
                if r:
                    gk_results.append(r)
        else:
            print("  'oecd' column not in panel — skipping OECD / non-OECD splits (D3–D6)")

    if gk_results:
        md = ["# Service Sector / Non-Tradable Channel Tests\n"]
        md.append("| Model | Dep Var | N | Countries | R² |")
        md.append("|---|---|---|---|---|")
        for r in gk_results:
            md.append(f"| {r['label']} | {r['dep_var']} | {r['n_obs']:,} "
                      f"| {r['n_countries']} | {r['r_squared']:.3f} |")
        md.append("\n## Key Coefficients\n")
        md.append("| Model | Variable | Coef | SE | p-value | Sig |")
        md.append("|---|---|---|---|---|---|")
        for r in gk_results:
            for var in demo_vars + ['health_exp_gdp']:
                ckey = f'coef_{var}'
                if ckey in r:
                    p = r[f'p_{var}']
                    md.append(f"| {r['label']} | {var} | {r[ckey]:.4f} "
                              f"| {r[f'se_{var}']:.4f} | {p:.4f} | {stars(p)} |")
        out = _save_table("nontradable_gk.md", md)
        print(f"\n  Saved: {out}")

    # ═══════════════════════════════════════════════════════════════════
    # PART E: LEAVE-ONE-OUT STABILITY
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART E: LEAVE-ONE-OUT STABILITY")
    print("=" * 70)

    est_loo = df.dropna(subset=['log_reer_combined'] + all_vars).copy()
    loo_countries = est_loo['iso3'].unique()
    loo_results = []

    for drop_iso in loo_countries:
        sub = est_loo[est_loo['iso3'] != drop_iso]
        gls_loo = PanelGLS()
        try:
            gls_loo.fit(sub['log_reer_combined'].values, sub[all_vars].values,
                        sub['iso3'].values, sub['year'].values)
            loo_results.append({
                'dropped': drop_iso,
                'Z_1_coef': gls_loo.beta[z1_idx],
                'Z_1_p': gls_loo.pvalues[z1_idx],
                'n_obs': gls_loo.n_obs,
            })
        except Exception:
            continue

    if loo_results:
        loo_df = pd.DataFrame(loo_results)
        z1_coefs = loo_df['Z_1_coef']
        always_sig = (loo_df['Z_1_p'] < 0.05).all()
        always_negative = (loo_df['Z_1_coef'] < 0).all()
        n_sig = (loo_df['Z_1_p'] < 0.05).sum()

        print(f"\n  Leave-one-out results ({len(loo_results)} countries):")
        print(f"    Z₁ range: [{z1_coefs.min():.4f}, {z1_coefs.max():.4f}]")
        print(f"    Z₁ mean: {z1_coefs.mean():.4f}")
        print(f"    Z₁ SD across LOO: {z1_coefs.std():.4f}")
        print(f"    Always significant at 5%: {'YES' if always_sig else 'NO'}")
        print(f"    Always negative: {'YES' if always_negative else 'NO'}")
        print(f"    Significant in {n_sig}/{len(loo_results)} iterations")

        # Flag any influential countries: dropping them moves Z₁ by more than
        # half of the baseline coefficient's magnitude.
        outlier_threshold = base_z1 * 0.5  # >50% change
        influential = loo_df[np.abs(loo_df['Z_1_coef'] - base_z1) >
                             np.abs(outlier_threshold)]
        if len(influential) > 0:
            print(f"\n  Influential countries (>50% coefficient change):")
            for _, row in influential.iterrows():
                print(f"    Drop {row['dropped']}: Z₁={row['Z_1_coef']:.4f} (p={row['Z_1_p']:.4f})")

        md = ["# Leave-One-Out Stability\n"]
        md.append(f"- Baseline Z₁ = {base_z1:.4f}")
        md.append(f"- LOO Z₁ range: [{z1_coefs.min():.4f}, {z1_coefs.max():.4f}]")
        md.append(f"- Always significant at 5%: {'Yes' if always_sig else 'No'} "
                  f"({n_sig}/{len(loo_results)})")
        md.append(f"- Always negative: {'Yes' if always_negative else 'No'}")
        if len(influential) > 0:
            md.append(f"\n## Influential Countries (>50% coefficient change)")
            md.append("| Dropped | Z₁ Coef | p-value |")
            md.append("|---|---|---|")
            for _, row in influential.iterrows():
                md.append(f"| {row['dropped']} | {row['Z_1_coef']:.4f} | {row['Z_1_p']:.4f} |")
        out = _save_table("leave_one_out.md", md)
        print(f"  Saved: {out}")

    # ═══════════════════════════════════════════════════════════════════
    # PART F: JACKKNIFE OVER REGIONS
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART F: JACKKNIFE OVER REGIONS")
    print("=" * 70)

    # Define regions
    regions = {
        'East Asia': ['CHN', 'JPN', 'KOR', 'TWN', 'HKG', 'SGP', 'THA', 'MYS',
                      'IDN', 'PHL', 'VNM', 'MMR', 'KHM', 'LAO'],
        'South Asia': ['IND', 'BGD', 'PAK', 'LKA', 'NPL'],
        'Sub-Saharan Africa': ['NGA', 'ZAF', 'KEN', 'GHA', 'ETH', 'TZA', 'UGA',
                               'SEN', 'CIV', 'CMR', 'MOZ', 'ZMB', 'BWA', 'MUS',
                               'NAM', 'RWA', 'MDG'],
        'Latin America': ['BRA', 'MEX', 'ARG', 'COL', 'CHL', 'PER', 'ECU', 'BOL',
                          'URY', 'PRY', 'VEN', 'CRI', 'PAN', 'DOM', 'GTM', 'HND',
                          'SLV', 'NIC', 'JAM', 'TTO'],
        'MENA': ['SAU', 'ARE', 'QAT', 'KWT', 'BHR', 'OMN', 'EGY', 'MAR', 'TUN',
                 'JOR', 'LBN', 'IRQ', 'IRN', 'ISR', 'DZA'],
        'Western Europe': ['DEU', 'FRA', 'GBR', 'ITA', 'ESP', 'NLD', 'BEL', 'AUT',
                           'CHE', 'SWE', 'NOR', 'DNK', 'FIN', 'IRL', 'PRT', 'GRC',
                           'LUX'],
        'Eastern Europe': ['POL', 'CZE', 'HUN', 'ROU', 'BGR', 'HRV', 'SVK', 'SVN',
                           'EST', 'LVA', 'LTU', 'SRB', 'UKR', 'BLR', 'RUS'],
        'Anglo': ['USA', 'CAN', 'AUS', 'NZL'],
    }

    jack_results = []
    for region_name, region_isos in regions.items():
        sub = df[~df['iso3'].isin(region_isos)].copy()
        r = run_model(sub, 'log_reer_combined', all_vars,
                      f"Drop {region_name}", silent=True)
        if r:
            jack_results.append({
                'dropped_region': region_name,
                # Size of the region definition, not the number of its
                # countries that are actually present in the sample.
                'n_dropped': len(region_isos),
                'Z_1_coef': r.get('coef_Z_1', np.nan),
                'Z_1_p': r.get('p_Z_1', np.nan),
                'n_obs': r['n_obs'],
                'n_countries': r['n_countries'],
            })

    if jack_results:
        print(f"\n  {'Region dropped':<25} {'Z₁':>10} {'p':>8} {'Sig':>4} {'N':>6} {'Countries':>10}")
        for j in jack_results:
            sig = stars(j['Z_1_p'])
            print(f"  {j['dropped_region']:<25} {j['Z_1_coef']:>10.4f} {j['Z_1_p']:>8.4f} "
                  f"{sig:>4} {j['n_obs']:>6} {j['n_countries']:>10}")

        md = ["# Regional Jackknife\n"]
        md.append("| Region Dropped | N dropped | Z₁ Coef | p-value | Sig | N obs | Countries |")
        md.append("|---|---|---|---|---|---|---|")
        for j in jack_results:
            sig = stars(j['Z_1_p'])
            md.append(f"| {j['dropped_region']} | {j['n_dropped']} | {j['Z_1_coef']:.4f} "
                      f"| {j['Z_1_p']:.4f} | {sig} | {j['n_obs']} | {j['n_countries']} |")
        out = _save_table("regional_jackknife.md", md)
        print(f"\n  Saved: {out}")

    # ═══════════════════════════════════════════════════════════════════
    # PART G: ALTERNATIVE DEMOGRAPHIC MEASURES
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART G: ALTERNATIVE DEMOGRAPHIC MEASURES")
    print("=" * 70)

    # (label, demographic measure(s)); each is combined with the baseline controls.
    alt_specs = [
        ("G1: OADR only", ['old_dep']),
        ("G2: youth_dep only", ['youth_dep']),
        ("G3: working_age_share", ['working_age_share']),
        ("G4: total_dep", ['total_dep']),
        ("G5: life_expectancy", ['life_expectancy']),
        ("G6: LE + LE²", ['life_expectancy', 'life_expectancy_sq']),
    ]
    alt_results = []
    for label, measures in alt_specs:
        # Skip a spec whose measure is absent; otherwise run_model would drop
        # it and report a controls-only model under this label.
        if not all(m in df.columns for m in measures):
            print(f"  [{label}] Demographic measure not in panel — skipping")
            continue
        r = run_model(df, 'log_reer_combined', measures + controls_bs, label)
        if r:
            alt_results.append(r)

    key_alt_vars = ['old_dep', 'youth_dep', 'working_age_share', 'total_dep',
                    'life_expectancy', 'life_expectancy_sq', 'Z_1']
    if alt_results:
        md = ["# Alternative Demographic Measures\n"]
        md.append("| Model | Dep Var | N | Countries | R² |")
        md.append("|---|---|---|---|---|")
        for r in alt_results:
            md.append(f"| {r['label']} | {r['dep_var']} | {r['n_obs']:,} "
                      f"| {r['n_countries']} | {r['r_squared']:.3f} |")
        md.append("\n## Key Coefficients\n")
        md.append("| Model | Variable | Coef | SE | p-value | Sig |")
        md.append("|---|---|---|---|---|---|")
        for r in alt_results:
            for var in key_alt_vars:
                ckey = f'coef_{var}'
                if ckey in r:
                    p = r[f'p_{var}']
                    md.append(f"| {r['label']} | {var} | {r[ckey]:.4f} "
                              f"| {r[f'se_{var}']:.4f} | {p:.4f} | {stars(p)} |")
        out = _save_table("alt_demographics.md", md)
        print(f"\n  Saved: {out}")

    print("\n" + "=" * 70)
    print("Phase 5 complete.")
    print("=" * 70)


if __name__ == "__main__":
    main()
